#!/usr/bin/env python
"""Namelist creator for CIME's driver.
"""
# Typically ignore this.
# pylint: disable=invalid-name

# Disable these because this is our standard setup
# pylint: disable=wildcard-import,unused-wildcard-import,wrong-import-position

import os, sys, glob, itertools, re

_CIMEROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..","..","..","..")
sys.path.append(os.path.join(_CIMEROOT, "scripts", "Tools"))

from standard_script_setup import *
from CIME.case import Case
from CIME.nmlgen import NamelistGenerator
from CIME.utils import expect, safe_copy
from CIME.utils import get_model, get_time_in_seconds, get_timestamp
from CIME.buildnml import create_namelist_infile, parse_input
from CIME.XML.files import Files
from CIME.XML.grids import Grids

logger = logging.getLogger(__name__)

###############################################################################
def _create_drv_namelists(case, infile, confdir, nmlgen, files):
###############################################################################
    """Build the driver namelist settings and write the driver input files.

    Case XML values are collected into a config dictionary that is used to
    initialize the namelist defaults. Selected entries (coupling intervals,
    start date, profiling and pause settings) are then overridden before
    ``drv_in``, ``seq_maps.rc`` and ``drv_flds_in`` are written.

    Args:
        case: case object queried through ``get_value``/``get_values``.
        infile: namelist input file list handed to ``nmlgen.init_defaults``.
        confdir: directory that receives ``drv_in`` and ``seq_maps.rc``.
        nmlgen: namelist generator holding the driver namelist definition.
        files: passed through to ``write_drv_flds_in_file``.

    Raises:
        Whatever ``expect`` raises when NCPL_BASE_PERIOD or CALENDAR is
        unsupported, when a component's NCPL does not divide the base period
        evenly, or when an active atmosphere does not run on the shortest
        coupling interval.
    """

    #--------------------------------
    # Set up config dictionary
    #--------------------------------
    config = {}
    config['cime_model'] = get_model()
    # The first underscore-separated token of COMPSET is used as 'iyear'.
    config['iyear'] = case.get_value('COMPSET').split('_')[0]
    config['BGC_MODE'] = case.get_value("CCSM_BGC")
    config['CPL_I2O_PER_CAT'] = case.get_value('CPL_I2O_PER_CAT')
    config['COMP_RUN_BARRIERS'] = case.get_value('COMP_RUN_BARRIERS')
    config['DRV_THREADING'] = case.get_value('DRV_THREADING')
    config['CPL_ALBAV'] = case.get_value('CPL_ALBAV')
    config['CPL_EPBAL'] = case.get_value('CPL_EPBAL')
    config['FLDS_WISO'] = case.get_value('FLDS_WISO')
    config['BUDGETS'] = case.get_value('BUDGETS')
    config['MACH'] = case.get_value('MACH')
    config['MPILIB'] = case.get_value('MPILIB')
    config['MULTI_DRIVER'] = '.true.' if case.get_value('MULTI_DRIVER') else '.false.'
    config['OS'] = case.get_value('OS')
    config['glc_nec'] = case.get_value('GLC_NEC')
    config['single_column'] = 'true' if case.get_value('PTS_MODE') else 'false'
    config['timer_level'] = 'pos' if case.get_value('TIMER_LEVEL') >= 1 else 'neg'
    config['bfbflag'] = 'on' if case.get_value('BFBFLAG') else 'off'
    config['continue_run'] = '.true.' if case.get_value('CONTINUE_RUN') else '.false.'
    config['atm_grid'] = case.get_value('ATM_GRID')
    config['compocn'] = case.get_value('COMP_OCN')

    docn_mode = case.get_value("DOCN_MODE")
    if docn_mode and 'aqua' in docn_mode:
        config['aqua_planet_sst_type'] = docn_mode
    else:
        config['aqua_planet_sst_type'] = 'none'

    # 'hybrid' is mapped onto 'startup'; any other RUN_TYPE leaves
    # config['run_type'] unset.
    run_type = case.get_value('RUN_TYPE')
    if run_type in ('startup', 'hybrid'):
        config['run_type'] = 'startup'
    elif run_type == 'branch':
        config['run_type'] = 'branch'

    #----------------------------------------------------
    # Initialize namelist defaults
    #----------------------------------------------------
    nmlgen.init_defaults(infile, config)

    #--------------------------------
    # Overwrite: set brnch_retain_casename
    #--------------------------------
    start_type = nmlgen.get_value('start_type')
    if start_type != 'startup':
        if case.get_value('CASE') == case.get_value('RUN_REFCASE'):
            nmlgen.set_value('brnch_retain_casename' , value='.true.')

    #--------------------------------
    # Overwrite: set component coupling frequencies
    #--------------------------------
    ncpl_base_period = case.get_value('NCPL_BASE_PERIOD')
    seconds_per_day = 3600 * 24
    if ncpl_base_period == 'hour':
        basedt = 3600
    elif ncpl_base_period == 'day':
        basedt = seconds_per_day
    elif ncpl_base_period in ('year', 'decade'):
        # The fixed 365-day year length used below is only accepted for the
        # NO_LEAP calendar.
        expect(case.get_value('CALENDAR') == 'NO_LEAP',
               "Invalid CALENDAR for NCPL_BASE_PERIOD {} ".format(ncpl_base_period))
        years = 10 if ncpl_base_period == 'decade' else 1
        basedt = seconds_per_day * 365 * years
    else:
        expect(False, "invalid NCPL_BASE_PERIOD NCPL_BASE_PERIOD {} ".format(ncpl_base_period))

    # Each component's coupling interval is the base period divided by its
    # NCPL; mindt tracks the shortest interval seen.
    mindt = basedt
    for comp in case.get_values("COMP_CLASSES"):
        ncpl = case.get_value(comp.upper() + '_NCPL')
        if ncpl is None:
            continue
        ncpl = int(ncpl)
        # Floor division keeps the arithmetic in exact integers.
        cpl_dt = basedt // ncpl
        expect(cpl_dt * ncpl == basedt,
               " {} ncpl doesn't divide base dt evenly".format(comp))
        nmlgen.add_default(comp.lower() + '_cpl_dt', value=cpl_dt)
        mindt = min(mindt, cpl_dt)

    # sanity check
    comp_atm = case.get_value("COMP_ATM")
    if comp_atm is not None and comp_atm not in ('datm', 'xatm', 'satm'):
        atmdt = basedt // int(case.get_value('ATM_NCPL'))
        expect(atmdt == mindt, 'Active atm should match shortest model timestep atmdt={} mindt={}'
               .format(atmdt, mindt))

    #--------------------------------
    # Overwrite: set start_ymd
    #--------------------------------
    # RUN_STARTDATE with its '-' separators removed.
    run_startdate = case.get_value('RUN_STARTDATE').replace('-', '')
    nmlgen.set_value('start_ymd', value=run_startdate)

    #--------------------------------
    # Overwrite: set tprof_option and tprof_n - if tprof_total is > 0
    #--------------------------------
    # This would be better handled inside the alarm logic in the driver routines.
    # Here supporting only nday(s), nmonth(s), and nyear(s).

    stop_option = case.get_value('STOP_OPTION')
    if 'nyear' in stop_option:
        tprofoption = 'ndays'
        tprofmult = 365
    elif 'nmonth' in stop_option:
        tprofoption = 'ndays'
        tprofmult = 30
    elif 'nday' in stop_option:
        tprofoption = 'ndays'
        tprofmult = 1
    else:
        tprofmult = 1
        tprofoption = 'never'

    tprof_total = case.get_value('TPROF_TOTAL')
    if ((tprof_total > 0) and (case.get_value('STOP_DATE') < 0) and ('ndays' in tprofoption)):
        # Spread tprof_total profiling intervals over the run length in days,
        # with a floor of one day per interval.
        stopn = tprofmult * case.get_value('STOP_N')
        tprofn = max(1, stopn // tprof_total)
        nmlgen.set_value('tprof_option', value=tprofoption)
        nmlgen.set_value('tprof_n'     , value=tprofn)

    # Set up the pause_component_list if pause is active
    pauseo = case.get_value('PAUSE_OPTION')
    if pauseo is not None and pauseo != 'never' and pauseo != 'none':
        pausen = case.get_value('PAUSE_N')
        # Set esp interval
        if 'nstep' in pauseo:
            esp_time = mindt
        else:
            esp_time = get_time_in_seconds(pausen, pauseo)

        nmlgen.set_value('esp_cpl_dt', value=esp_time)
    # End if pause is active

    #--------------------------------
    # (1) Write output namelist file drv_in and  input dataset list.
    #--------------------------------
    write_drv_in_file(case, nmlgen, confdir)

    #--------------------------------
    # (2) Write out seq_map.rc file
    #--------------------------------
    write_seq_maps_file(case, nmlgen, confdir)

    #--------------------------------
    # (3) Construct and write out drv_flds_in
    #--------------------------------
    write_drv_flds_in_file(case, nmlgen, files)

###############################################################################
def write_drv_in_file(case, nmlgen, confdir):
###############################################################################
    """Write the driver namelist ``drv_in`` together with its input data list.

    An existing ``Buildconf/cpl.input_data_list`` under the case root is
    deleted first; ``nmlgen.write_output_file`` is then called with the
    ``drv_in`` path inside ``confdir`` and the data list path.
    """
    buildconf_dir = os.path.join(case.get_case_root(), "Buildconf")
    input_data_list = os.path.join(buildconf_dir, "cpl.input_data_list")
    drv_in_path = os.path.join(confdir, "drv_in")

    # Remove a list left over from an earlier run before regenerating it.
    if os.path.exists(input_data_list):
        os.remove(input_data_list)

    nmlgen.write_output_file(drv_in_path, input_data_list)

###############################################################################
def write_seq_maps_file(case, nmlgen, confdir):
###############################################################################
    """Check ``idmap`` mapping settings for consistency, then write ``seq_maps.rc``.

    For grid definition versions 2.0 and later, every ``*mapname*`` entry of
    the ``seq_maps`` group whose value is ``idmap`` must connect two
    components that share the same grid, unless one of them is on the
    'null' grid. ``rof2ocn_`` entries are exempt from that check.
    """
    # Record each component's grid (coupler and ESP excluded) and whether
    # that grid is 'null', in which case the component is ignored below.
    skipped_classes = ["CPL","ESP"]
    gridvalue = {}
    ignore_component = {}
    for comp_class in case.get_values("COMP_CLASSES"):
        if comp_class in skipped_classes:
            continue
        comp_key = comp_class.lower()
        comp_grid = case.get_value(comp_class + "_GRID")
        gridvalue[comp_key] = comp_grid
        ignore_component[comp_key] = (comp_grid == 'null')

    # The ATM grid may carry vertical grid information ("zxx" where xx is a
    # number of levels); drop everything from the last 'z' onwards.
    if 'atm' in gridvalue and 'z' in gridvalue['atm']:
        atm_grid = gridvalue['atm']
        gridvalue['atm'] = atm_grid[:atm_grid.rfind('z')]

    # Currently, hard-wire values of mapping file names to ignore
    # TODO: for rof2ocn_fmapname - needs to be resolved since this is currently
    # used in prep_ocn_mod.F90 if flood_present is True - this is in issue #1908.
    # The following is only appropriate for config_grids.xml version 2.0 or later
    if Grids().get_version() >= 2.0:
        group_variables = nmlgen.get_group_variables("seq_maps")
        for name in group_variables:
            if "mapname" not in name:
                continue
            mapfile = re.sub('\"', '', group_variables[name])
            if mapfile != 'idmap':
                continue

            # Entry names start with "<src>2<dst>", three letters per component.
            src, dst = name[0:3], name[4:7]
            if ignore_component[src] or ignore_component[dst]:
                continue

            if "rof2ocn_" not in name:
                expect(gridvalue[src] == gridvalue[dst],
                       "Need to provide valid mapping file between {} and {} in xml variable {} ".\
                       format(src, dst, name))
            elif case.get_value("COMP_OCN") == 'docn':
                logger.warning("   NOTE: ignoring setting of {}=idmap in seq_maps.rc".format(name))

    # now write out the file
    nmlgen.write_seq_maps(os.path.join(confdir, "seq_maps.rc"))

###############################################################################
def write_drv_flds_in_file(case, nmlgen, files):
###############################################################################
    """Combine the component ``drv_flds_in`` fragments into ``CaseDocs/drv_flds_in``.

    Candidate file names come from the ``drv_flds_in_files`` default of
    ``nmlgen`` and are resolved relative to CASEROOT; names that do not
    exist on disk are skipped. When at least one exists, the files are
    checked pairwise for conflicting ``name = value`` settings and then
    merged by a separate namelist generator built from
    ``namelist_definition_drv_flds.xml``. Nothing is written when no
    candidate file exists.
    """
    # In the following, all values come simply from the infiles - no default values need to be added
    # FIXME - do want to add the possibility that will use a user definition file for drv_flds_in
    caseroot = case.get_value('CASEROOT')

    nmlgen.add_default('drv_flds_in_files')
    candidates = (os.path.join(caseroot, relname)
                  for relname in nmlgen.get_default('drv_flds_in_files'))
    infiles = [path for path in candidates if os.path.isfile(path)]
    if not infiles:
        return

    # First read the drv_flds_in files and make sure that
    # for any key there are not two conflicting values
    settings_by_file = {}
    for path in infiles:
        settings = {}
        with open(path) as handle:
            for line in handle:
                # Only lines with an '=' and no '!' anywhere are recorded.
                if "=" not in line or '!' in line:
                    continue
                name, _, var = line.partition("=")
                settings[name.strip()] = var.strip()
        settings_by_file[path] = settings

    for first, second in itertools.combinations(settings_by_file.keys(), 2):
        compare_drv_flds_in(settings_by_file[first], settings_by_file[second], first, second)

    # Now create drv_flds_in
    definition_dir = os.path.dirname(files.get_value("NAMELIST_DEFINITION_FILE", attribute={"component":"drv"}))
    definition_file = [os.path.join(definition_dir, "namelist_definition_drv_flds.xml")]
    flds_nmlgen = NamelistGenerator(case, definition_file, files=files)
    flds_nmlgen.init_defaults(infiles, {}, skip_entry_loop=True)
    flds_nmlgen.write_output_file(os.path.join(caseroot, "CaseDocs", "drv_flds_in"))

###############################################################################
def compare_drv_flds_in(first, second, infile1, infile2):
###############################################################################
    """Fail through ``expect`` if two settings dictionaries disagree on a shared key.

    ``first`` and ``second`` map setting names to values as read from
    ``infile1`` and ``infile2``. Keys present in only one dictionary are
    ignored. The first conflicting key found is printed with both values
    before ``expect`` is invoked with the two file names.
    """
    common_keys = set(first.keys()).intersection(second.keys())
    conflicting = (key for key in common_keys if first[key] != second[key])
    for key in conflicting:
        print('Key: {}, \n Value 1: {}, \n Value 2: {}'.format(key, first[key], second[key]))
        expect(False, "incompatible settings in drv_flds_in from \n {} \n and \n {}".format(infile1, infile2))

###############################################################################
def _create_component_modelio_namelists(case, files):
###############################################################################
    """Write a ``<comp>_modelio.nml`` file for each instance of each component class.

    The files are placed in ``CASEBUILD/cplconf``. When a component has more
    than one instance, the file name and the log file name carry a
    ``_NNNN`` instance suffix. The log file name ends with the ``LID``
    environment variable when it is set, otherwise with a fresh timestamp.
    """
    # will need to create a new namelist generator
    definition_dir = os.path.dirname(files.get_value("NAMELIST_DEFINITION_FILE", attribute={"component":"drv"}))
    definition_file = [os.path.join(definition_dir, "namelist_definition_modelio.xml")]
    no_infiles = []

    confdir = os.path.join(case.get_value("CASEBUILD"), "cplconf")

    lid = os.environ.get("LID")
    if lid is None:
        lid = get_timestamp("%y%m%d-%H%M%S")

    # if we are in multi-coupler mode the number of instances of cpl will be the max
    # of any NINST_* value
    maxinst = case.get_value("NINST_MAX") if case.get_value("MULTI_DRIVER") else 1

    for comp_class in case.get_values("COMP_CLASSES"):
        model = comp_class.lower()
        with NamelistGenerator(case, definition_file) as nmlgen:
            entries = nmlgen.init_defaults(no_infiles, {'component': model}, skip_entry_loop=True)

            # The coupler, and every component in multi-driver mode with
            # more than one instance, uses maxinst; otherwise the
            # component's own NINST_* value applies.
            if model == 'cpl' or maxinst != 1:
                inst_count = maxinst
            else:
                inst_count = case.get_value("NINST_" + model.upper())

            for inst_index in range(1, inst_count + 1):
                # determine instance string
                inst_string = '_{:04d}'.format(inst_index) if inst_count > 1 else ""

                # set default values
                for entry in entries:
                    nmlgen.add_default(entry)

                # overwrite defaults
                nmlgen.set_value('diri', "/".join((case.get_value('EXEROOT'), model)))
                nmlgen.set_value('diro', case.get_value('RUNDIR'))
                nmlgen.set_value('logfile', "{}{}.log.{}".format(model, inst_string, lid))

                # Write output file
                modelio_file = "{}_modelio.nml{}".format(model, inst_string)
                nmlgen.write_modelio_file(os.path.join(confdir, modelio_file))

###############################################################################
def buildnml(case, caseroot, component):
###############################################################################
    """Generate the driver namelist files and copy them to the run directory.

    Args:
        case: case object queried through ``get_value``.
        caseroot: case root directory; must contain ``SourceMods/src.drv``.
        component: must be ``"drv"``.

    Raises:
        AttributeError: if ``component`` is not ``"drv"``.
    """
    if component != "drv":
        raise AttributeError

    confdir = os.path.join(case.get_value("CASEBUILD"), "cplconf")
    if not os.path.isdir(confdir):
        os.makedirs(confdir)

    # NOTE: User definition *replaces* existing definition.
    # TODO: Append instead of replace?
    user_xml_dir = os.path.join(caseroot, "SourceMods", "src.drv")
    expect(os.path.isdir(user_xml_dir),
           "user_xml_dir {} does not exist ".format(user_xml_dir))

    files = Files(comp_interface="mct")
    default_definition = files.get_value("NAMELIST_DEFINITION_FILE", {"component": "drv"})
    user_definition = os.path.join(user_xml_dir, "namelist_definition_drv.xml")
    if os.path.isfile(user_definition):
        definition_file = [user_definition]
    else:
        definition_file = [default_definition]

    # Create the namelist generator object - independent of instance
    nmlgen = NamelistGenerator(case, definition_file)

    # create cplconf/namelist
    # cam is actually changing the driver namelist settings
    if case.get_value('COMP_ATM') == 'cam' and "aquaplanet" in case.get_value("CAM_CONFIG_OPTS"):
        infile_text = "aqua_planet = .true. \n aqua_planet_sst = 1"
    else:
        infile_text = ""

    namelist_infile = os.path.join(confdir, "namelist_infile")
    create_namelist_infile(case, os.path.join(caseroot, "user_nl_cpl"),
                           namelist_infile, infile_text)

    # create the files drv_in, drv_flds_in and seq_maps.rc
    _create_drv_namelists(case, [namelist_infile], confdir, nmlgen, files)

    # create the files comp_modelio.nml where comp = [atm, lnd...]
    _create_component_modelio_namelists(case, files)

    # copy drv_in, drv_flds_in, seq_maps.rc and all *modelio* files to rundir
    rundir = case.get_value("RUNDIR")

    safe_copy(os.path.join(confdir,"drv_in"), rundir)

    # drv_flds_in is only copied when it exists.
    drv_flds_in = os.path.join(caseroot, "CaseDocs", "drv_flds_in")
    if os.path.isfile(drv_flds_in):
        safe_copy(drv_flds_in, rundir)

    safe_copy(os.path.join(confdir,"seq_maps.rc"), rundir)

    for modelio_path in glob.glob(os.path.join(confdir, "*modelio*")):
        safe_copy(modelio_path, rundir)

###############################################################################
def _main_func():
    """Script entry point: build the driver namelists for the case given on the command line."""
    # parse_input turns the command-line arguments into the case root.
    case_dir = parse_input(sys.argv)

    with Case(case_dir) as case:
        buildnml(case=case, caseroot=case_dir, component="drv")

# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    _main_func()
